package com.zhangc.limiter;

import java.io.IOException;
import java.io.PrintWriter;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.alibaba.fastjson.JSONObject;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

/**
 * Single-node rate-limiting filter backed by an in-JVM Guava cache.
 * <p>
 * Requests are counted along three dimensions, checked in this order (the first exceeded limit wins
 * and later dimensions are not evaluated for that request):
 * <ol>
 * <li>by resource URI,</li>
 * <li>by remote user ({@link HttpServletRequest#getRemoteUser()}) combined with the URI,</li>
 * <li>by client IP ({@link HttpServletRequest#getRemoteAddr()}) combined with the URI.</li>
 * </ol>
 * Counters are entries in a Guava cache with {@code expireAfterWrite}. A counter is written once when
 * created and only incremented afterwards, so each limit means "requests per window" (SCHEDULE_TIME
 * seconds by default), not strictly requests per second.
 * <p>
 * Limits and the on/off switch are read through {@code FilterScmUtilHelper}. When that configuration is
 * missing or the check itself fails, the request is allowed (fail-open). Counters are held in local
 * memory, so limits apply per JVM instance.
 */
public class RateLimiterFilter implements Filter {

    /** Default counting window in seconds: a counter expires this long after it is created. */
    private static final long SCHEDULE_TIME = 60L;
    /** Default maximum number of counters kept in the cache. */
    private static final long MAX_CACHE_SIZE = 10000L;
    /** Config key holding the URI-dimension limits. */
    private static final String WEB_RESOURCE = "webResource";
    /** Config key holding the user-dimension limits. */
    private static final String WEB_USER = "webUser";
    /** Config key holding the IP-dimension limits. */
    private static final String WEB_IP = "webIp";
    /** Config key of the master on/off switch. */
    private static final String CON_SWITCH = "controlSwitch";
    /** Switch value that enables rate limiting; any other value disables it. */
    private static final String DO_CON = "1";
    /** Message returned to throttled clients ("too many requests, please try again later"). */
    private static final String TOKEN_STR = "操作频繁,请稍后再试";
    /** Sentinel returned by {@link #getLimitNum(String, String)} when no rule matches the URI. */
    private static final long NO_LIMIT = -1L;
    private static final Logger LOGGER = LoggerFactory.getLogger(RateLimiterFilter.class);
    /**
     * Request counters keyed by dimension token. Declared {@code volatile} because
     * {@link #setAccountCounterCache(String, String)} can replace the instance at runtime while request
     * threads are reading it.
     */
    private static volatile Cache<String, AtomicInteger> accountCounterCache = CacheBuilder.newBuilder()
            .expireAfterWrite(SCHEDULE_TIME, TimeUnit.SECONDS).maximumSize(MAX_CACHE_SIZE).build();

    /**
     * Passes the request down the chain unless a configured limit has been exceeded. In that case a
     * JSON error body is written instead.
     * <p>
     * Only the limit check is guarded. If it throws, the error is logged and the request is allowed
     * (fail-open). Exceptions thrown by the downstream chain propagate unchanged, so a request is never
     * processed twice.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException,
            ServletException {
        RateLimiterEnum rateEnum = null;
        try {
            rateEnum = checkToken(request);
        } catch (Exception e) {
            LOGGER.error("阀值访问异常错误", e);
        }
        if (rateEnum == null) {
            chain.doFilter(request, response);
        } else {
            setResponse(response, TOKEN_STR);
        }
    }

    /**
     * Checks whether the request exceeds any configured threshold.
     *
     * @param request the incoming request; must be an {@link HttpServletRequest} when limiting is enabled
     * @return the dimension whose limit was exceeded, or {@code null} if the request may proceed
     *         (including when the master switch is off)
     * @throws ClassCastException if limiting is enabled and the request is not an HTTP request
     */
    public static RateLimiterEnum checkToken(ServletRequest request) {
        // Master switch: anything other than "1" disables limiting entirely.
        String controlSwitch = FilterScmUtilHelper.getValueByKey(CON_SWITCH);
        if (!DO_CON.equals(controlSwitch)) {
            return null;
        }
        HttpServletRequest httpRequest = (HttpServletRequest) request;
        // URI, then user, then IP; stop at the first dimension that is over its limit.
        RateLimiterEnum rateEnum = isWebMax(httpRequest);
        if (rateEnum == null) {
            rateEnum = isUserMax(httpRequest);
        }
        if (rateEnum == null) {
            rateEnum = isIpMax(httpRequest);
        }
        if (rateEnum != null) {
            setLog(httpRequest.getRequestURI() + rateEnum.value());
        }
        return rateEnum;
    }

    /**
     * Checks the URI dimension. The counter token is the request URI itself.
     *
     * @param request the incoming request
     * @return {@code RateLimiterEnum.WEB} if the limit is exceeded, otherwise {@code null}
     *         (also {@code null} when unconfigured or on any error)
     */
    public static RateLimiterEnum isWebMax(HttpServletRequest request) {
        try {
            String uri = request.getRequestURI();
            return checkDimension(WEB_RESOURCE, uri, uri, RateLimiterEnum.WEB);
        } catch (Exception e) {
            LOGGER.error("判断web资源纬度是否达到阀指上限错误", e);
            return null;
        }
    }

    /**
     * Checks the user dimension. Requests without a remote user are not limited here. The counter
     * token is the user name concatenated with the URI.
     *
     * @param request the incoming request
     * @return {@code RateLimiterEnum.USER} if the limit is exceeded, otherwise {@code null}
     *         (also {@code null} when unconfigured or on any error)
     */
    public static RateLimiterEnum isUserMax(HttpServletRequest request) {
        try {
            String userName = request.getRemoteUser();
            if (StringUtils.isBlank(userName)) {
                return null;
            }
            String uri = request.getRequestURI();
            // NOTE(review): plain concatenation without a separator or dimension prefix can collide
            // with IP tokens or other users (e.g. "a/b" + "/x" == "a" + "/b/x"). The key format is
            // kept as-is because the cache is exposed via getAccountCounterCache().
            return checkDimension(WEB_USER, uri, userName + uri, RateLimiterEnum.USER);
        } catch (Exception e) {
            LOGGER.error("判断用户纬度是否达到阀指上限错误", e);
            return null;
        }
    }

    /**
     * Checks the IP dimension. The counter token is the remote address concatenated with the URI.
     *
     * @param request the incoming request
     * @return {@code RateLimiterEnum.IP} if the limit is exceeded, otherwise {@code null}
     *         (also {@code null} when unconfigured or on any error)
     */
    public static RateLimiterEnum isIpMax(HttpServletRequest request) {
        try {
            String uri = request.getRequestURI();
            String ip = request.getRemoteAddr();
            return checkDimension(WEB_IP, uri, ip + uri, RateLimiterEnum.IP);
        } catch (Exception e) {
            LOGGER.error("判断IP纬度是否达到阀指上限错误", e);
            return null;
        }
    }

    /**
     * Shared logic for one dimension: read its rules, find the limit for {@code uri}, and count
     * {@code token} against that limit.
     *
     * @return {@code type} when the limit is exceeded, otherwise {@code null} (also when the config
     *         key is blank or no rule matches)
     */
    private static RateLimiterEnum checkDimension(String configKey, String uri, String token,
            RateLimiterEnum type) {
        String resource = FilterScmUtilHelper.getValueByKey(configKey);
        // Missing configuration node: do not limit.
        if (StringUtils.isBlank(resource)) {
            return null;
        }
        Long limitNum = getLimitNum(uri, resource);
        return isLimitFlag(token, limitNum) ? type : null;
    }

    /**
     * Parses the limit that applies to {@code uri} from a rule string.
     * <p>
     * The rule string is a {@code ;}-separated list of {@code fragment*limit} entries, for example
     * {@code "/order*100;/pay*20"}. The first entry whose fragment is a substring of {@code uri} wins.
     * An empty fragment (e.g. {@code "*10"}) matches every URI. Whitespace around entries, fragments and
     * numbers is ignored. Entries that do not have exactly one {@code *} are skipped.
     *
     * @param uri      the request URI
     * @param resource the rule string
     * @return the matching limit, or {@code -1} if nothing matches (or either argument is missing)
     * @throws NumberFormatException if the matching entry's limit is not a valid long
     */
    public static Long getLimitNum(String uri, String resource) {
        if (uri == null || StringUtils.isBlank(resource)) {
            return NO_LIMIT;
        }
        for (String entry : resource.split(";")) {
            if (StringUtils.isBlank(entry)) {
                continue;
            }
            String[] parts = entry.trim().split("[*]");
            if (parts.length == 2 && uri.contains(parts[0].trim())) {
                return Long.valueOf(parts[1].trim());
            }
        }
        return NO_LIMIT;
    }

    /**
     * Counts one request for {@code token} and reports whether its limit for the current window has
     * been exceeded.
     * <p>
     * Counter creation is atomic ({@code putIfAbsent}), so concurrent first requests for the same token
     * share one counter instead of overwriting each other. Incrementing does not write to the cache, so
     * the window runs from the counter's creation.
     *
     * @param token    the counter key
     * @param limitNum the maximum number of requests allowed per window; {@code -1} or {@code null}
     *                 means no rule applies
     * @return {@code true} if this request pushes the count above {@code limitNum}
     */
    public static boolean isLimitFlag(String token, Long limitNum) {
        // No matching rule: do not count, allow the request.
        if (limitNum == null || limitNum == NO_LIMIT) {
            return false;
        }
        // Read the volatile field once so this call works against a single cache instance.
        Cache<String, AtomicInteger> cache = accountCounterCache;
        AtomicInteger counter = cache.getIfPresent(token);
        if (counter == null) {
            AtomicInteger created = new AtomicInteger(0);
            AtomicInteger existing = cache.asMap().putIfAbsent(token, created);
            counter = existing == null ? created : existing;
        }
        return counter.incrementAndGet() > limitNum;
    }

    /**
     * Writes the throttling response as a UTF-8 JSON body:
     * {@code {"success":false,"errorCode":"403","errorMsg":<rateObj>}}.
     * The HTTP status code is not set here.
     *
     * @param response the response to write to
     * @param rateObj  the error message placed in {@code errorMsg}
     * @throws IOException if the response writer cannot be obtained
     */
    public void setResponse(ServletResponse response, String rateObj) throws IOException {
        JSONObject responseJSONObject = new JSONObject();
        responseJSONObject.put("success", false);
        responseJSONObject.put("errorCode", "403");
        responseJSONObject.put("errorMsg", rateObj);
        response.setCharacterEncoding("UTF-8");
        response.setContentType("application/json; charset=utf-8");
        PrintWriter out = response.getWriter();
        out.append(responseJSONObject.toString());
    }

    /**
     * Returns the live counter cache.
     *
     * @return the cache currently used for counting
     */
    public static Cache<String, AtomicInteger> getAccountCounterCache() {
        return accountCounterCache;
    }

    /**
     * Replaces the counter cache with one using a new window and capacity. All existing counters are
     * discarded.
     *
     * @param scheduleTimeNew the new window length in seconds, as a decimal string
     * @param maxCacheSizeNew the new maximum number of counters, as a decimal string
     * @throws NumberFormatException if either argument is not a valid long; the current cache is then
     *                               left in place
     */
    public static void setAccountCounterCache(String scheduleTimeNew, String maxCacheSizeNew) {
        accountCounterCache = CacheBuilder.newBuilder()
                .expireAfterWrite(Long.parseLong(scheduleTimeNew), TimeUnit.SECONDS)
                .maximumSize(Long.parseLong(maxCacheSizeNew)).build();
    }

    /**
     * Common log output: writes {@code logStr} at debug level when debug logging is enabled.
     *
     * @param logStr the message to log
     */
    public static void setLog(String logStr) {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug(logStr);
        }
    }

    @Override
    public void init(FilterConfig arg0) throws ServletException {
        LOGGER.debug("init");
    }

    @Override
    public void destroy() {
        LOGGER.debug("destroy");
    }

}
